{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 配置GPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n"
     ]
    }
   ],
   "source": [
    "# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda')\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 超参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training configuration: tune these values to change the experiment.\n",
    "num_epochs = 5          # full passes over the training loader\n",
    "num_classes = 10        # size of the network's output layer\n",
    "batch_size = 100        # samples per mini-batch, used by both data loaders\n",
    "learning_rate = 0.001   # step size handed to the optimizer\n",
    "\n",
    "# Seed PyTorch's random number generators so weight initialisation and batch\n",
    "# shuffling repeat from run to run (some CUDA kernels may still be non-deterministic).\n",
    "seed = 42\n",
    "torch.manual_seed(seed);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Just for fun: use the Fashion-MNIST dataset.\n",
    "fashion_root = '../../data/fashion'\n",
    "\n",
    "# Training split; download=True asks torchvision to fetch the files into fashion_root.\n",
    "train_dataset = torchvision.datasets.FashionMNIST(\n",
    "    root=fashion_root,\n",
    "    train=True,\n",
    "    download=True,\n",
    "    transform=transforms.ToTensor(),\n",
    ")\n",
    "\n",
    "# Test split; read from the same root, with no download request of its own.\n",
    "test_dataset = torchvision.datasets.FashionMNIST(\n",
    "    root=fashion_root,\n",
    "    train=False,\n",
    "    transform=transforms.ToTensor(),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mini-batch iterators (DataLoader) over the two splits.\n",
    "# Training batches are shuffled; test batches keep the dataset order.\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    train_dataset, batch_size=batch_size, shuffle=True\n",
    ")\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    test_dataset, batch_size=batch_size, shuffle=False\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Convolutional neural network (two convolutional layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convolutional neural network with two convolutional layers\n",
    "class ConvNet(nn.Module):\n",
    "    \"\"\"Two convolutional stages followed by one linear classifier.\n",
    "\n",
    "    Each stage is Conv2d(5x5, stride 1, padding 2) -> BatchNorm2d -> ReLU ->\n",
    "    MaxPool2d(2, stride 2), so the spatial size is halved twice. The linear\n",
    "    layer is sized for 7*7*32 flattened features, e.g. a 1x28x28 input.\n",
    "\n",
    "    Args:\n",
    "        num_classes: Number of scores produced per sample.\n",
    "    \"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    def _conv_stage(in_channels, out_channels):\n",
    "        \"\"\"Build one conv -> batch-norm -> ReLU -> max-pool stage.\"\"\"\n",
    "        return nn.Sequential(\n",
    "            nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=2),\n",
    "            nn.BatchNorm2d(out_channels),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(kernel_size=2, stride=2),\n",
    "        )\n",
    "\n",
    "    def __init__(self, num_classes=10):\n",
    "        super().__init__()\n",
    "        # Stages are created in the same order as before so parameter names\n",
    "        # (layer1.*, layer2.*, fc.*) and initialisation order are unchanged.\n",
    "        self.layer1 = self._conv_stage(1, 16)\n",
    "        self.layer2 = self._conv_stage(16, 32)\n",
    "        self.fc = nn.Linear(7 * 7 * 32, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return class scores of shape (batch, num_classes) for the batch x.\"\"\"\n",
    "        features = self.layer2(self.layer1(x))\n",
    "        # Collapse channel and spatial dimensions into one vector per sample.\n",
    "        flat = features.reshape(features.size(0), -1)\n",
    "        return self.fc(flat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the network, then move its parameters and buffers to the chosen device.\n",
    "model = ConvNet(num_classes=num_classes)\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 损失 和 优化器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cross-entropy loss on the model's class scores.\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "# Adam over every parameter of the model, using the configured learning rate.\n",
    "optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/5], Step[100/600], Loss: 0.4434\n",
      "Epoch [1/5], Step[200/600], Loss: 0.2976\n",
      "Epoch [1/5], Step[300/600], Loss: 0.3106\n",
      "Epoch [1/5], Step[400/600], Loss: 0.4081\n",
      "Epoch [1/5], Step[500/600], Loss: 0.3521\n",
      "Epoch [1/5], Step[600/600], Loss: 0.2117\n",
      "Epoch [2/5], Step[100/600], Loss: 0.3241\n",
      "Epoch [2/5], Step[200/600], Loss: 0.3203\n",
      "Epoch [2/5], Step[300/600], Loss: 0.2124\n",
      "Epoch [2/5], Step[400/600], Loss: 0.2234\n",
      "Epoch [2/5], Step[500/600], Loss: 0.1273\n",
      "Epoch [2/5], Step[600/600], Loss: 0.2325\n",
      "Epoch [3/5], Step[100/600], Loss: 0.1918\n",
      "Epoch [3/5], Step[200/600], Loss: 0.2580\n",
      "Epoch [3/5], Step[300/600], Loss: 0.1905\n",
      "Epoch [3/5], Step[400/600], Loss: 0.2110\n",
      "Epoch [3/5], Step[500/600], Loss: 0.2038\n",
      "Epoch [3/5], Step[600/600], Loss: 0.3344\n",
      "Epoch [4/5], Step[100/600], Loss: 0.1478\n",
      "Epoch [4/5], Step[200/600], Loss: 0.3134\n",
      "Epoch [4/5], Step[300/600], Loss: 0.2849\n",
      "Epoch [4/5], Step[400/600], Loss: 0.2735\n",
      "Epoch [4/5], Step[500/600], Loss: 0.4289\n",
      "Epoch [4/5], Step[600/600], Loss: 0.2682\n",
      "Epoch [5/5], Step[100/600], Loss: 0.2806\n",
      "Epoch [5/5], Step[200/600], Loss: 0.2875\n",
      "Epoch [5/5], Step[300/600], Loss: 0.1472\n",
      "Epoch [5/5], Step[400/600], Loss: 0.2305\n",
      "Epoch [5/5], Step[500/600], Loss: 0.2297\n",
      "Epoch [5/5], Step[600/600], Loss: 0.1776\n"
     ]
    }
   ],
   "source": [
    "total_step = len(train_loader)  # number of mini-batches per epoch\n",
    "for epoch in range(num_epochs):\n",
    "    for step, (images, labels) in enumerate(train_loader, start=1):\n",
    "        images, labels = images.to(device), labels.to(device)\n",
    "\n",
    "        # Forward pass\n",
    "        outputs = model(images)\n",
    "        loss = criterion(outputs, labels)\n",
    "\n",
    "        # Backward pass and parameter update\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Only report the current batch loss on every 100th step.\n",
    "        if step % 100 != 0:\n",
    "            continue\n",
    "        print('Epoch [{}/{}], Step[{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, step, total_step, loss.item()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Accuracy of the model on the 1000 test images: 86.0%\n"
     ]
    }
   ],
   "source": [
    "# Evaluate in inference mode so BatchNorm uses its running statistics instead of\n",
    "# per-batch statistics. The previous mode is remembered and restored afterwards,\n",
    "# so re-running the training cell later is unaffected.\n",
    "was_training = model.training\n",
    "model.eval()\n",
    "\n",
    "correct = 0  # correctly classified samples, accumulated over all batches\n",
    "total = 0    # samples seen, accumulated over all batches\n",
    "with torch.no_grad():  # gradients are not needed for evaluation\n",
    "    for images, labels in test_loader:\n",
    "        images = images.to(device)\n",
    "        labels = labels.to(device)\n",
    "        outputs = model(images)\n",
    "        # The predicted label is the index of the largest score along dim 1.\n",
    "        predicted = torch.argmax(outputs, dim=1)\n",
    "        # Count samples (not batches) and accumulate hits across batches.\n",
    "        total += labels.size(0)\n",
    "        correct += (predicted == labels).sum().item()\n",
    "\n",
    "model.train(was_training)\n",
    "\n",
    "# Guard against an empty loader so the report never divides by zero.\n",
    "accuracy = 100 * correct / total if total else float('nan')\n",
    "print('Test Accuracy of the model on the {} test images: {}%'.format(total, accuracy))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 保存模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save only the model's state_dict rather than the module object itself.\n",
    "checkpoint_path = 'model.ckpt'\n",
    "torch.save(model.state_dict(), checkpoint_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tensorflow-gpu",
   "language": "python",
   "name": "tensorflow-gpu"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
